/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package org.infinispan.client.hotrod.impl.transport.netty;

import java.util.concurrent.TimeUnit;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.Channel.Unsafe;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelPromise;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.internal.ObjectUtil;

/**
 * Triggers an {@link IdleStateEvent} when a {@link Channel} has not performed
 * read, write, or both operations for a while.
 *
 * <h3>Supported idle states</h3>
 * <table border="1">
 * <tr>
 * <th>Property</th><th>Meaning</th>
 * </tr>
 * <tr>
 * <td>{@code readerIdleTime}</td>
 * <td>an {@link IdleStateEvent} whose state is {@link IdleState#READER_IDLE}
 *     will be triggered when no read was performed for the specified period of
 *     time.  Specify {@code 0} to disable.</td>
 * </tr>
 * <tr>
 * <td>{@code writerIdleTime}</td>
 * <td>an {@link IdleStateEvent} whose state is {@link IdleState#WRITER_IDLE}
 *     will be triggered when no write was performed for the specified period of
 *     time.  Specify {@code 0} to disable.</td>
 * </tr>
 * <tr>
 * <td>{@code allIdleTime}</td>
 * <td>an {@link IdleStateEvent} whose state is {@link IdleState#ALL_IDLE}
 *     will be triggered when neither read nor write was performed for the
 *     specified period of time.  Specify {@code 0} to disable.</td>
 * </tr>
 * </table>
 *
 * <pre>
 * // An example that sends a ping message when there is no outbound traffic
 * // for 30 seconds.  The connection is closed when there is no inbound traffic
 * // for 60 seconds.
 *
 * public class MyChannelInitializer extends {@link ChannelInitializer}&lt;{@link Channel}&gt; {
 *     {@code @Override}
 *     public void initChannel({@link Channel} channel) {
 *         channel.pipeline().addLast("idleStateHandler", new {@link IdleStateHandler}(60, 30, 0));
 *         channel.pipeline().addLast("myHandler", new MyHandler());
 *     }
 * }
 *
 * // Handler should handle the {@link IdleStateEvent} triggered by {@link IdleStateHandler}.
 * public class MyHandler extends {@link ChannelDuplexHandler} {
 *     {@code @Override}
 *     public void userEventTriggered({@link ChannelHandlerContext} ctx, {@link Object} evt) throws {@link Exception} {
 *         if (evt instanceof {@link IdleStateEvent}) {
 *             {@link IdleStateEvent} e = ({@link IdleStateEvent}) evt;
 *             if (e.state() == {@link IdleState}.READER_IDLE) {
 *                 ctx.close();
 *             } else if (e.state() == {@link IdleState}.WRITER_IDLE) {
 *                 ctx.writeAndFlush(new PingMessage());
 *             }
 *         }
 *     }
 * }
 *
 * {@link ServerBootstrap} bootstrap = ...;
 * ...
 * bootstrap.childHandler(new MyChannelInitializer());
 * ...
 * </pre>
 *
 * @see ReadTimeoutHandler
 * @see WriteTimeoutHandler
 */
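// This class was copied from Netty 4.1.109 (io.netty.handler.timeout.IdleStateHandler) and modified so that
// write() records the last write time directly when the write is issued, instead of unvoiding the promise and
// registering a listener that records it on completion. Netty's instance fields are private, so the original
// handler could not simply be extended.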
public class IdleStateHandlerNoUnvoid extends ChannelDuplexHandler {
   /** Lower bound applied to every enabled (positive) idle time. */
   private static final long MIN_TIMEOUT_NANOS = TimeUnit.MILLISECONDS.toNanos(1);

   /** Whether progress in the channel's outbound buffer is taken into account when assessing write idleness. */
   private final boolean observeOutput;
   // Idle thresholds in nanoseconds; 0 means the corresponding check is disabled.
   private final long readerIdleTimeNanos;
   private final long writerIdleTimeNanos;
   private final long allIdleTimeNanos;

   private Future<?> readerIdleTimeout;
   private long lastReadTime;
   private boolean firstReaderIdleEvent = true;

   private Future<?> writerIdleTimeout;
   private long lastWriteTime;
   private boolean firstWriterIdleEvent = true;

   private Future<?> allIdleTimeout;
   private boolean firstAllIdleEvent = true;

   // Lifecycle: 0 (default) = not yet initialized -> ST_INITIALIZED -> ST_DESTROYED.
   // initialize() is a no-op in either of the latter two states, so a destroyed handler is never re-armed.
   private byte state;
   private static final byte ST_INITIALIZED = 1;
   private static final byte ST_DESTROYED = 2;

   // Set in channelRead() and cleared in channelReadComplete(); while set, the reader and all-idle
   // tasks do not count elapsed time against the idle threshold.
   private boolean reading;

   // Outbound-buffer snapshot compared by hasOutputChanged() when observeOutput is enabled.
   private long lastChangeCheckTimeStamp;
   private int lastMessageHashCode;
   private long lastPendingWriteBytes;
   private long lastFlushProgress;

   /**
    * Creates a new instance firing {@link IdleStateEvent}s.
    *
    * @param readerIdleTimeSeconds
    *        an {@link IdleStateEvent} whose state is {@link IdleState#READER_IDLE}
    *        will be triggered when no read was performed for the specified
    *        period of time.  Specify {@code 0} to disable.
    * @param writerIdleTimeSeconds
    *        an {@link IdleStateEvent} whose state is {@link IdleState#WRITER_IDLE}
    *        will be triggered when no write was performed for the specified
    *        period of time.  Specify {@code 0} to disable.
    * @param allIdleTimeSeconds
    *        an {@link IdleStateEvent} whose state is {@link IdleState#ALL_IDLE}
    *        will be triggered when neither read nor write was performed for
    *        the specified period of time.  Specify {@code 0} to disable.
    */
   public IdleStateHandlerNoUnvoid(
         int readerIdleTimeSeconds,
         int writerIdleTimeSeconds,
         int allIdleTimeSeconds) {

      this(readerIdleTimeSeconds, writerIdleTimeSeconds, allIdleTimeSeconds,
            TimeUnit.SECONDS);
   }

   /**
    * Creates a new instance with {@code observeOutput} disabled.
    *
    * @see #IdleStateHandlerNoUnvoid(boolean, long, long, long, TimeUnit)
    */
   public IdleStateHandlerNoUnvoid(
         long readerIdleTime, long writerIdleTime, long allIdleTime,
         TimeUnit unit) {
      this(false, readerIdleTime, writerIdleTime, allIdleTime, unit);
   }

   /**
    * Creates a new instance firing {@link IdleStateEvent}s.
    *
    * <p>Any positive idle time is converted to nanoseconds and raised to at least 1 ms;
    * a value {@code <= 0} disables the corresponding check.
    *
    * @param observeOutput
    *        whether or not the consumption of {@code bytes} should be taken into
    *        consideration when assessing write idleness. The default is {@code false}.
    * @param readerIdleTime
    *        an {@link IdleStateEvent} whose state is {@link IdleState#READER_IDLE}
    *        will be triggered when no read was performed for the specified
    *        period of time.  Specify {@code 0} to disable.
    * @param writerIdleTime
    *        an {@link IdleStateEvent} whose state is {@link IdleState#WRITER_IDLE}
    *        will be triggered when no write was performed for the specified
    *        period of time.  Specify {@code 0} to disable.
    * @param allIdleTime
    *        an {@link IdleStateEvent} whose state is {@link IdleState#ALL_IDLE}
    *        will be triggered when neither read nor write was performed for
    *        the specified period of time.  Specify {@code 0} to disable.
    * @param unit
    *        the {@link TimeUnit} of {@code readerIdleTime},
    *        {@code writerIdleTime}, and {@code allIdleTime}
    * @throws NullPointerException if {@code unit} is {@code null}
    */
   public IdleStateHandlerNoUnvoid(boolean observeOutput,
                           long readerIdleTime, long writerIdleTime, long allIdleTime,
                           TimeUnit unit) {
      ObjectUtil.checkNotNull(unit, "unit");

      this.observeOutput = observeOutput;

      if (readerIdleTime <= 0) {
         readerIdleTimeNanos = 0;
      } else {
         readerIdleTimeNanos = Math.max(unit.toNanos(readerIdleTime), MIN_TIMEOUT_NANOS);
      }
      if (writerIdleTime <= 0) {
         writerIdleTimeNanos = 0;
      } else {
         writerIdleTimeNanos = Math.max(unit.toNanos(writerIdleTime), MIN_TIMEOUT_NANOS);
      }
      if (allIdleTime <= 0) {
         allIdleTimeNanos = 0;
      } else {
         allIdleTimeNanos = Math.max(unit.toNanos(allIdleTime), MIN_TIMEOUT_NANOS);
      }
   }

   /**
    * Returns the reader idle time given at construction, in milliseconds
    * (after clamping to at least 1 ms; {@code 0} if disabled).
    */
   public long getReaderIdleTimeInMillis() {
      return TimeUnit.NANOSECONDS.toMillis(readerIdleTimeNanos);
   }

   /**
    * Returns the writer idle time given at construction, in milliseconds
    * (after clamping to at least 1 ms; {@code 0} if disabled).
    */
   public long getWriterIdleTimeInMillis() {
      return TimeUnit.NANOSECONDS.toMillis(writerIdleTimeNanos);
   }

   /**
    * Returns the all-idle time given at construction, in milliseconds
    * (after clamping to at least 1 ms; {@code 0} if disabled).
    */
   public long getAllIdleTimeInMillis() {
      return TimeUnit.NANOSECONDS.toMillis(allIdleTimeNanos);
   }

   @Override
   public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
      if (ctx.channel().isActive() && ctx.channel().isRegistered()) {
         // channelActive() event has been fired already, which means this.channelActive() will
         // not be invoked. We have to initialize here instead.
         initialize(ctx);
      }
      // Otherwise channelActive() has not been fired yet: this.channelActive() will be invoked
      // and initialization will occur there.
   }

   @Override
   public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
      destroy();
   }

   @Override
   public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
      // Initialize early if channel is active already.
      if (ctx.channel().isActive()) {
         initialize(ctx);
      }
      super.channelRegistered(ctx);
   }

   @Override
   public void channelActive(ChannelHandlerContext ctx) throws Exception {
      // This method will be invoked only if this handler was added
      // before channelActive() event is fired.  If a user adds this handler
      // after the channelActive() event, initialize() will be called by handlerAdded().
      initialize(ctx);
      super.channelActive(ctx);
   }

   @Override
   public void channelInactive(ChannelHandlerContext ctx) throws Exception {
      destroy();
      super.channelInactive(ctx);
   }

   @Override
   public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
      if (readerIdleTimeNanos > 0 || allIdleTimeNanos > 0) {
         reading = true;
         firstReaderIdleEvent = firstAllIdleEvent = true;
      }
      ctx.fireChannelRead(msg);
   }

   @Override
   public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
      if ((readerIdleTimeNanos > 0 || allIdleTimeNanos > 0) && reading) {
         lastReadTime = ticksInNanos();
         reading = false;
      }
      ctx.fireChannelReadComplete();
   }

   /**
    * Records write activity and forwards the write unchanged.
    *
    * <p>Unlike Netty's {@code IdleStateHandler}, the promise is never unvoided and no completion
    * listener is added: the last write time is taken when the write is issued, not when it completes.
    */
   @Override
   public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
      // Only track writes when a write-related idle check is enabled.
      if (writerIdleTimeNanos > 0 || allIdleTimeNanos > 0) {
         lastWriteTime = ticksInNanos();
         firstWriterIdleEvent = firstAllIdleEvent = true;
      }
      ctx.write(msg, promise);
   }

   /**
    * Reset the read timeout. As this handler is not thread-safe, this method <b>must</b> be called on the event loop.
    */
   public void resetReadTimeout() {
      if (readerIdleTimeNanos > 0 || allIdleTimeNanos > 0) {
         lastReadTime = ticksInNanos();
         reading = false;
      }
   }

   /**
    * Reset the write timeout. As this handler is not thread-safe, this method <b>must</b> be called on the event loop.
    */
   public void resetWriteTimeout() {
      if (writerIdleTimeNanos > 0 || allIdleTimeNanos > 0) {
         lastWriteTime = ticksInNanos();
      }
   }

   /**
    * Schedules the enabled idle checks. Does nothing if the handler was already initialized or destroyed.
    */
   private void initialize(ChannelHandlerContext ctx) {
      // Avoid the case where destroy() is called before scheduling timeouts.
      // See: https://github.com/netty/netty/issues/143
      if (state == ST_INITIALIZED || state == ST_DESTROYED) {
         return;
      }

      state = ST_INITIALIZED;
      initOutputChanged(ctx);

      lastReadTime = lastWriteTime = ticksInNanos();
      if (readerIdleTimeNanos > 0) {
         readerIdleTimeout = schedule(ctx, new ReaderIdleTimeoutTask(ctx),
               readerIdleTimeNanos, TimeUnit.NANOSECONDS);
      }
      if (writerIdleTimeNanos > 0) {
         writerIdleTimeout = schedule(ctx, new WriterIdleTimeoutTask(ctx),
               writerIdleTimeNanos, TimeUnit.NANOSECONDS);
      }
      if (allIdleTimeNanos > 0) {
         allIdleTimeout = schedule(ctx, new AllIdleTimeoutTask(ctx),
               allIdleTimeNanos, TimeUnit.NANOSECONDS);
      }
   }

   /**
    * This method is visible for testing!
    */
   long ticksInNanos() {
      return System.nanoTime();
   }

   /**
    * This method is visible for testing!
    */
   Future<?> schedule(ChannelHandlerContext ctx, Runnable task, long delay, TimeUnit unit) {
      return ctx.executor().schedule(task, delay, unit);
   }

   /**
    * Marks the handler destroyed and cancels (without interrupting) every scheduled idle check.
    */
   private void destroy() {
      state = ST_DESTROYED;

      if (readerIdleTimeout != null) {
         readerIdleTimeout.cancel(false);
         readerIdleTimeout = null;
      }
      if (writerIdleTimeout != null) {
         writerIdleTimeout.cancel(false);
         writerIdleTimeout = null;
      }
      if (allIdleTimeout != null) {
         allIdleTimeout.cancel(false);
         allIdleTimeout = null;
      }
   }

   /**
    * Is called when an {@link IdleStateEvent} should be fired. This implementation calls
    * {@link ChannelHandlerContext#fireUserEventTriggered(Object)}.
    */
   protected void channelIdle(ChannelHandlerContext ctx, IdleStateEvent evt) throws Exception {
      ctx.fireUserEventTriggered(evt);
   }

   /**
    * Returns the shared {@link IdleStateEvent} constant for the given state and "first" flag.
    *
    * @throws IllegalArgumentException if {@code state} is not one of the three known states
    */
   protected IdleStateEvent newIdleStateEvent(IdleState state, boolean first) {
      switch (state) {
         case ALL_IDLE:
            return first ? IdleStateEvent.FIRST_ALL_IDLE_STATE_EVENT : IdleStateEvent.ALL_IDLE_STATE_EVENT;
         case READER_IDLE:
            return first ? IdleStateEvent.FIRST_READER_IDLE_STATE_EVENT : IdleStateEvent.READER_IDLE_STATE_EVENT;
         case WRITER_IDLE:
            return first ? IdleStateEvent.FIRST_WRITER_IDLE_STATE_EVENT : IdleStateEvent.WRITER_IDLE_STATE_EVENT;
         default:
            throw new IllegalArgumentException("Unhandled: state=" + state + ", first=" + first);
      }
   }

   /**
    * Takes the initial outbound-buffer snapshot used by {@code hasOutputChanged} when
    * {@code observeOutput} is enabled.
    *
    * @see #hasOutputChanged(ChannelHandlerContext, boolean)
    */
   private void initOutputChanged(ChannelHandlerContext ctx) {
      if (observeOutput) {
         Channel channel = ctx.channel();
         Unsafe unsafe = channel.unsafe();
         ChannelOutboundBuffer buf = unsafe.outboundBuffer();

         if (buf != null) {
            lastMessageHashCode = System.identityHashCode(buf.current());
            lastPendingWriteBytes = buf.totalPendingWriteBytes();
            lastFlushProgress = buf.currentProgress();
         }
      }
   }

   /**
    * Returns {@code true} if and only if this handler was constructed with {@code observeOutput}
    * enabled and there has been an observed change in the {@link ChannelOutboundBuffer} between
    * two consecutive calls of this method.
    *
    * <p>When {@code first} is {@code true} the stored snapshot is still refreshed, but the method
    * returns {@code false}.
    *
    * https://github.com/netty/netty/issues/6150
    */
   private boolean hasOutputChanged(ChannelHandlerContext ctx, boolean first) {
      if (observeOutput) {

         // Shortcut: in this handler lastWriteTime is updated whenever write() is called (not when
         // the write promise completes), so a changed value means new writes were issued since the
         // previous check. We treat that as "change" on the message level and assume change on the
         // byte level too. If the user doesn't observe channel writability events then they'll
         // eventually OOME and there's clearly a different problem and idleness is least of their
         // concerns.
         if (lastChangeCheckTimeStamp != lastWriteTime) {
            lastChangeCheckTimeStamp = lastWriteTime;

            // But this applies only if it's the non-first call.
            if (!first) {
               return true;
            }
         }

         Channel channel = ctx.channel();
         Unsafe unsafe = channel.unsafe();
         ChannelOutboundBuffer buf = unsafe.outboundBuffer();

         if (buf != null) {
            int messageHashCode = System.identityHashCode(buf.current());
            long pendingWriteBytes = buf.totalPendingWriteBytes();

            if (messageHashCode != lastMessageHashCode || pendingWriteBytes != lastPendingWriteBytes) {
               lastMessageHashCode = messageHashCode;
               lastPendingWriteBytes = pendingWriteBytes;

               if (!first) {
                  return true;
               }
            }

            long flushProgress = buf.currentProgress();
            if (flushProgress != lastFlushProgress) {
               lastFlushProgress = flushProgress;
               return !first;
            }
         }
      }

      return false;
   }

   /**
    * Base for the scheduled idle checks: skips the check, and does not reschedule,
    * once the channel is no longer open.
    */
   private abstract static class AbstractIdleTask implements Runnable {

      private final ChannelHandlerContext ctx;

      AbstractIdleTask(ChannelHandlerContext ctx) {
         this.ctx = ctx;
      }

      @Override
      public void run() {
         if (!ctx.channel().isOpen()) {
            return;
         }

         run(ctx);
      }

      protected abstract void run(ChannelHandlerContext ctx);
   }

   /** Fires READER_IDLE when no read has completed within {@code readerIdleTimeNanos}. */
   private final class ReaderIdleTimeoutTask extends AbstractIdleTask {

      ReaderIdleTimeoutTask(ChannelHandlerContext ctx) {
         super(ctx);
      }

      @Override
      protected void run(ChannelHandlerContext ctx) {
         long nextDelay = readerIdleTimeNanos;
         // While a read is in progress the full period is used, so no event can fire.
         if (!reading) {
            nextDelay -= ticksInNanos() - lastReadTime;
         }

         if (nextDelay <= 0) {
            // Reader is idle - set a new timeout and notify the callback.
            // Rescheduling first means destroy() from within the callback cancels the new timeout.
            readerIdleTimeout = schedule(ctx, this, readerIdleTimeNanos, TimeUnit.NANOSECONDS);

            boolean first = firstReaderIdleEvent;
            firstReaderIdleEvent = false;

            try {
               IdleStateEvent event = newIdleStateEvent(IdleState.READER_IDLE, first);
               channelIdle(ctx, event);
            } catch (Throwable t) {
               ctx.fireExceptionCaught(t);
            }
         } else {
            // Read occurred before the timeout - set a new timeout with shorter delay.
            readerIdleTimeout = schedule(ctx, this, nextDelay, TimeUnit.NANOSECONDS);
         }
      }
   }

   /** Fires WRITER_IDLE when no write has been issued within {@code writerIdleTimeNanos}. */
   private final class WriterIdleTimeoutTask extends AbstractIdleTask {

      WriterIdleTimeoutTask(ChannelHandlerContext ctx) {
         super(ctx);
      }

      @Override
      protected void run(ChannelHandlerContext ctx) {

         long lastWriteTime = IdleStateHandlerNoUnvoid.this.lastWriteTime;
         long nextDelay = writerIdleTimeNanos - (ticksInNanos() - lastWriteTime);
         if (nextDelay <= 0) {
            // Writer is idle - set a new timeout and notify the callback.
            writerIdleTimeout = schedule(ctx, this, writerIdleTimeNanos, TimeUnit.NANOSECONDS);

            boolean first = firstWriterIdleEvent;
            firstWriterIdleEvent = false;

            try {
               // With observeOutput, outbound-buffer progress suppresses the event.
               if (hasOutputChanged(ctx, first)) {
                  return;
               }

               IdleStateEvent event = newIdleStateEvent(IdleState.WRITER_IDLE, first);
               channelIdle(ctx, event);
            } catch (Throwable t) {
               ctx.fireExceptionCaught(t);
            }
         } else {
            // Write occurred before the timeout - set a new timeout with shorter delay.
            writerIdleTimeout = schedule(ctx, this, nextDelay, TimeUnit.NANOSECONDS);
         }
      }
   }

   /** Fires ALL_IDLE when neither a read nor a write happened within {@code allIdleTimeNanos}. */
   private final class AllIdleTimeoutTask extends AbstractIdleTask {

      AllIdleTimeoutTask(ChannelHandlerContext ctx) {
         super(ctx);
      }

      @Override
      protected void run(ChannelHandlerContext ctx) {

         long nextDelay = allIdleTimeNanos;
         // Measure from the most recent of the last read and the last write.
         if (!reading) {
            nextDelay -= ticksInNanos() - Math.max(lastReadTime, lastWriteTime);
         }
         if (nextDelay <= 0) {
            // Both reader and writer are idle - set a new timeout and
            // notify the callback.
            allIdleTimeout = schedule(ctx, this, allIdleTimeNanos, TimeUnit.NANOSECONDS);

            boolean first = firstAllIdleEvent;
            firstAllIdleEvent = false;

            try {
               if (hasOutputChanged(ctx, first)) {
                  return;
               }

               IdleStateEvent event = newIdleStateEvent(IdleState.ALL_IDLE, first);
               channelIdle(ctx, event);
            } catch (Throwable t) {
               ctx.fireExceptionCaught(t);
            }
         } else {
            // Either read or write occurred before the timeout - set a new
            // timeout with shorter delay.
            allIdleTimeout = schedule(ctx, this, nextDelay, TimeUnit.NANOSECONDS);
         }
      }
   }
}
